import asyncio
import json
import os
import sys
from contextlib import AsyncExitStack
from typing import Optional

from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from openai import OpenAI

load_dotenv()  # load environment variables from .env

# Settings for the OpenAI-compatible chat endpoint, read from the environment
# (or the .env file loaded above) using these lowercase variable names.
api_key = os.getenv('api_key', '')
# An unset/empty base_url becomes None so the OpenAI client falls back to its
# default endpoint rather than being handed an unusable empty URL.
base_url = os.getenv('base_url') or None
# Model used for every chat completion; DeepSeek's hosted API calls it "deepseek-chat".
modelname = os.getenv('modelname') or 'deepseek-chat'
class MCPClient:
    """Chat client that lets an OpenAI-compatible model call the tools of an MCP server.

    Connect with ``connect_to_server`` (a local ``.py``/``.js`` script over stdio)
    or ``connect_to_sse_server`` (an SSE URL), then use ``process_query`` or
    ``chat_loop``. Call ``cleanup`` when done to close the connection.
    """

    def __init__(self, api_key=None, base_url=None, debug=False):
        """Create an unconnected client.

        Args:
            api_key: API key passed to the OpenAI client.
            base_url: Endpoint base URL passed to the OpenAI client. An empty
                string is treated like ``None`` so the client uses its default
                endpoint instead of an unusable empty URL.
            debug: If true, ``connect_to_sse_server`` also prints a traceback
                when connecting fails.
        """
        # Set by connect_to_server / connect_to_sse_server.
        self.session: Optional[ClientSession] = None
        # Owns every transport and session context so cleanup() closes them all.
        self.exit_stack = AsyncExitStack()
        # Read by connect_to_sse_server's error handler.
        self.debug = debug

        # One client serves every chat completion; both attribute names refer
        # to it so code using either name keeps working.
        self.openai = OpenAI(api_key=api_key, base_url=base_url or None)
        self.deepseek_client = self.openai

    async def connect_to_server(self, server_script_path: str):
        """Connect to an MCP server script over the stdio transport.

        The script is started with ``python`` (for ``.py``) or ``node`` (for
        ``.js``); the transport and session are registered on ``exit_stack``.

        Args:
            server_script_path: Path to the server script (.py or .js).

        Raises:
            ValueError: If the path does not end in ``.py`` or ``.js``.
        """
        is_python = server_script_path.endswith('.py')
        if not (is_python or server_script_path.endswith('.js')):
            raise ValueError("Server script must be a .py or .js file")

        server_params = StdioServerParameters(
            command="python" if is_python else "node",
            args=[server_script_path],
            env=None,
        )

        self.stdio, self.write = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
        await self.session.initialize()

        # List available tools so the user can see what the model may call.
        response = await self.session.list_tools()
        print("\nConnected to server with tools:", [tool.name for tool in response.tools])

    async def connect_to_sse_server(self, server_url: str, timeout=10.0, retries=1):
        """Connect to an MCP server running with the SSE transport.

        Args:
            server_url: URL passed to ``sse_client``.
            timeout: Accepted for API compatibility; currently unused.
            retries: Accepted for API compatibility; currently unused.

        Returns:
            True once the session is initialized and its tools were listed;
            False if any step raised (the error is printed, plus a traceback
            when ``self.debug`` is true).
        """
        try:
            # Registering both contexts on the exit stack keeps them open for the
            # life of the client and guarantees cleanup() closes them.
            streams = await self.exit_stack.enter_async_context(sse_client(url=server_url))
            self.session = await self.exit_stack.enter_async_context(ClientSession(*streams))

            await self.session.initialize()

            # List available tools to verify the connection works end to end.
            print("Initialized SSE client...")
            print("Listing tools...")
            response = await self.session.list_tools()
            print("\nConnected to server with tools:", [tool.name for tool in response.tools])
            return True

        except Exception as e:
            print(f"Connection error: {str(e)}")
            if self.debug:
                import traceback
                traceback.print_exc()
            return False

    @staticmethod
    def _parse_tool_args(raw_args):
        """Coerce a model tool call's ``arguments`` into the dict ``call_tool`` expects.

        An empty string or ``None`` becomes ``{}`` (a call with no arguments).
        A JSON string is decoded. A string that is not valid JSON becomes
        ``{"input": raw_args}``. A decoded or given dict is returned as is, and
        any other value becomes ``{"value": str(value)}``.
        """
        args = raw_args
        if isinstance(args, str):
            if not args.strip():
                return {}
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                return {"input": raw_args}
        if args is None:
            return {}
        if isinstance(args, dict):
            return args
        return {"value": str(args)}

    @staticmethod
    def _result_to_text(content):
        """Flatten an MCP tool result's ``content`` into a string for the chat model.

        For a list, each item's ``text`` attribute is used (``str(item)`` for
        items without one) and the pieces are joined with newlines. An empty
        list gives ``""``. A string is returned unchanged, and anything else is
        passed through ``str()``.
        """
        if isinstance(content, list):
            return "\n".join(
                item.text if hasattr(item, "text") else str(item) for item in content
            )
        if isinstance(content, str):
            return content
        return str(content)

    async def process_query(self, query: str) -> str:
        """Answer ``query``, letting the model call the connected server's tools.

        The query and the server's tool list go to the model. If the reply
        contains tool calls, every call is executed on the MCP server, all
        results are appended to the conversation, and the model is asked once
        more for the final answer.

        Args:
            query: The user's question.

        Returns:
            The model's first reply text, one ``[Calling tool ...]`` marker per
            tool call, and the follow-up reply, joined with newlines.
        """
        messages = [{"role": "user", "content": query}]

        response = await self.session.list_tools()
        available_tools = [{
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": tool.inputSchema,
            },
        } for tool in response.tools]
        print(f"available_tools: {available_tools}")

        response = self.deepseek_client.chat.completions.create(
            model=modelname,
            messages=messages,
            tools=available_tools,
            tool_choice="auto",
            temperature=0,
        )
        print(response)

        final_text = []
        assistant_message = response.choices[0].message
        print(f"assistant_message: {assistant_message}")
        final_text.append(assistant_message.content or "")
        print(f"assistant_message.tool_calls: {assistant_message.tool_calls}")

        tool_calls = getattr(assistant_message, "tool_calls", None)
        if not tool_calls:
            return "\n".join(final_text)

        # The assistant turn that requested the tools must appear exactly once,
        # followed by one "tool" message per tool_call_id, before the model is
        # asked again; otherwise the API rejects the conversation.
        messages.append(assistant_message)
        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            print(f"tool_name: {tool_name}")
            print(f"tool_args: {tool_call.function.arguments}")
            tool_args = self._parse_tool_args(tool_call.function.arguments)
            print(f"converted tool_args: {tool_args}")

            result = await self.session.call_tool(tool_name, tool_args)
            print(f"result: {result}")
            final_text.append(f"[Calling tool {tool_name} with args {tool_args}]")
            messages.append({
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_name,
                "content": self._result_to_text(result.content),
            })

        print(f"messages: {messages}")
        # Use the same client and configured model as the first request.
        response = self.deepseek_client.chat.completions.create(
            model=modelname,
            messages=messages,
        )
        print(f"response: {response}")
        final_text.append(response.choices[0].message.content or "")
        return "\n".join(final_text)

    async def chat_loop(self):
        """Read queries from stdin and print answers until 'quit' or end of input.

        Errors raised while answering a query are printed and the loop continues.
        """
        print("\nMCP Client Started!")
        print("Type your queries or 'quit' to exit.")

        while True:
            try:
                query = input("\nQuery: ").strip()
            except EOFError:
                # stdin is closed (Ctrl-D / exhausted pipe): every further
                # input() would fail immediately, so stop instead of spinning.
                break

            if query.lower() == 'quit':
                break
            if not query:
                continue

            try:
                response = await self.process_query(query)
                print("\n" + response)
            except Exception as e:
                print(f"\nError: {str(e)}")

    async def cleanup(self):
        """Close every transport and session registered on ``exit_stack``."""
        await self.exit_stack.aclose()

async def main():
    """Command-line entry point: connect to an MCP server and start the chat loop.

    The single argument is either a local server script (``.py`` or ``.js``,
    connected over stdio) or the URL of an SSE server. Credentials and endpoint
    come from the module-level ``api_key`` / ``base_url`` settings. The client
    is always cleaned up, including when the process exits early.
    """
    if len(sys.argv) < 2:
        print("Usage: python dsclient.py <path_to_server_script_or_sse_url>")
        sys.exit(1)

    target = sys.argv[1]
    client = MCPClient(api_key=api_key, base_url=base_url)
    try:
        # connect_to_server handles both .py and .js scripts; anything else is
        # treated as an SSE server URL.
        if target.endswith(('.py', '.js')):
            await client.connect_to_server(target)
        elif not await client.connect_to_sse_server(target):
            # Without a session every query would fail, so stop here.
            sys.exit(1)
        await client.chat_loop()
    finally:
        await client.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
